Draft

Amortized Bayesian Inference with PyTorch

Categories: pytorch, variational auto-encoders, amortized bayesian inference

Published: July 25, 2025

Heuristics in Latent Space
The cost of generating new sample data can be prohibitive, and a second, distinct cost attaches to the construction of novel data. Principal Components Analysis can be seen as a technique for optimally reconstructing a complex multivariate data set from a lower-dimensional compressed space. Variational auto-encoders achieve still more flexible reconstructions in non-linear settings. Drawing a new sample from the posterior predictive distribution of a Bayesian model similarly supplies insight into the variability of realised data. Both methods posit a latent model of the data-generating process that leverages a compressed representation of the data. These are different heuristics, with different consequences for how we understand variability in the world. Amortized Bayesian inference seeks to unite the two.

Reconstruction Error

It’s natural to seek shortcuts. Before quantifying reconstruction error, we load the libraries used throughout.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as dsets
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
import pymc as pm

Job Satisfaction Data

# Standard deviations
stds = np.array([0.939, 1.017, 0.937, 0.562, 0.760, 0.524, 
                 0.585, 0.609, 0.731, 0.711, 1.124, 1.001])

n = len(stds)

# Lower triangular correlation values as a flat list
corr_values = [
    1.000,
    .668, 1.000,
    .635, .599, 1.000,
    .263, .261, .164, 1.000,
    .290, .315, .247, .486, 1.000,
    .207, .245, .231, .251, .449, 1.000,
   -.206, -.182, -.195, -.309, -.266, -.142, 1.000,
   -.280, -.241, -.238, -.344, -.305, -.230,  .753, 1.000,
   -.258, -.244, -.185, -.255, -.255, -.215,  .554,  .587, 1.000,
    .080,  .096,  .094, -.017,  .151,  .141, -.074, -.111,  .016, 1.000,
    .061,  .028, -.035, -.058, -.051, -.003, -.040, -.040, -.018,  .284, 1.000,
    .113,  .174,  .059,  .063,  .138,  .044, -.119, -.073, -.084,  .563,  .379, 1.000
]

# Fill correlation matrix
corr_matrix = np.zeros((n, n))
idx = 0
for i in range(n):
    for j in range(i+1):
        corr_matrix[i, j] = corr_values[idx]
        corr_matrix[j, i] = corr_values[idx]
        idx += 1

# Covariance matrix: Sigma = D * R * D
cov_matrix = np.outer(stds, stds) * corr_matrix
#cov_matrix_test = np.dot(np.dot(np.diag(stds), corr_matrix), np.diag(stds))
columns=["JW1","JW2","JW3", "UF1","UF2","FOR", "DA1","DA2","DA3", "EBA","ST","MI"]
corr_df = pd.DataFrame(corr_matrix, columns=columns)

cov_df = pd.DataFrame(cov_matrix, columns=columns)
cov_df
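As a quick sanity check on the Σ = D·R·D construction above: the elementwise form `np.outer(stds, stds) * corr_matrix` agrees with the explicit matrix product, and a Cholesky factorization verifies that the result is a valid (positive-definite) covariance matrix. A minimal sketch on a hypothetical 3-variable case, not the survey data:

```python
import numpy as np

# Hypothetical 3-variable illustration (not the survey data): the
# elementwise construction np.outer(stds, stds) * R is identical to
# D @ R @ D with D = diag(stds).
stds = np.array([0.9, 1.1, 0.6])
R = np.array([[1.0, 0.5, 0.2],
              [0.5, 1.0, 0.3],
              [0.2, 0.3, 1.0]])

cov_outer = np.outer(stds, stds) * R
cov_diag = np.diag(stds) @ R @ np.diag(stds)
assert np.allclose(cov_outer, cov_diag)

# Cholesky succeeds only for positive-definite matrices, so this
# doubles as a validity check on the implied covariance.
L = np.linalg.cholesky(cov_outer)
assert np.allclose(L @ L.T, cov_outer)
```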

def make_sample(cov_matrix, size, columns):
    sample = np.random.multivariate_normal(np.zeros(len(columns)), cov_matrix, size=size)
    return pd.DataFrame(sample, columns=columns)

sample_df = make_sample(cov_matrix, 263, columns)
sample_df.head()
(Output: the first five rows of the simulated 263 × 12 sample. Exact values differ on each run because no random seed is set.)
data = sample_df.corr()

def plot_heatmap(data, title="Correlation Matrix", vmin=-.2, vmax=.2, ax=None, figsize=(10, 6), colorbar=True):
    data_matrix = data.values
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(data_matrix, cmap='viridis', vmin=vmin, vmax=vmax)

    for i in range(data_matrix.shape[0]):
        for j in range(data_matrix.shape[1]):
            ax.text(
                j, i,                             # x, y coordinates
                f"{data_matrix[i, j]:.2f}",       # cell value
                ha="center", va="center",         # centred in the cell
                color="white" if data_matrix[i, j] < 0.5 else "black"  # contrast color
            )

    ax.set_title(title)
    # Set tick positions before labels to avoid matplotlib's
    # fixed-locator warning.
    ax.set_xticks(np.arange(data.shape[1]))
    ax.set_xticklabels(data.columns)
    ax.set_yticks(np.arange(data.shape[0]))
    ax.set_yticklabels(data.index)
    if colorbar:
        plt.colorbar(im)

plot_heatmap(data, vmin=-1, vmax=1)

X = make_sample(cov_matrix, 100, columns=columns)
U, S, VT = np.linalg.svd(X, full_matrices=False)
ranks = [2, 5, 12]
reconstructions = []
for k in ranks:
    X_k = U[:, :k] @ np.diag(S[:k]) @ VT[:k, :]
    reconstructions.append(X_k)

# Plot original and reconstructed matrices
fig, axes = plt.subplots(1, len(ranks) + 1, figsize=(10,15))
axes[0].imshow(X, cmap='viridis')
axes[0].set_title("Original")
axes[0].axis("off")

for ax, k, X_k in zip(axes[1:], ranks, reconstructions):
    ax.imshow(X_k, cmap='viridis')
    ax.set_title(f"Rank {k}")
    ax.axis("off")

plt.suptitle("Reconstruction of Data Using SVD \n various truncation options",fontsize=12, x=.5, y=1.01)
plt.tight_layout()
plt.show()
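The visual comparison can be made quantitative. By the Eckart–Young theorem, the rank-k truncation is the best rank-k approximation in the Frobenius norm, and its error equals the norm of the discarded singular values. A small sketch on synthetic data, with a random 100 × 12 matrix standing in for our sample:

```python
import numpy as np

# Relative Frobenius reconstruction error of rank-k SVD truncations.
# A random 100 x 12 matrix stands in for the sampled survey data.
rng = np.random.default_rng(42)
X = rng.normal(size=(100, 12))
U, S, VT = np.linalg.svd(X, full_matrices=False)

errors = {}
for k in [2, 5, 12]:
    X_k = U[:, :k] @ np.diag(S[:k]) @ VT[:k, :]
    errors[k] = np.linalg.norm(X - X_k) / np.linalg.norm(X)

# Eckart-Young: the rank-k error equals the norm of the discarded
# singular values, and the full-rank reconstruction is exact.
assert np.isclose(errors[5], np.linalg.norm(S[5:]) / np.linalg.norm(X))
assert np.isclose(errors[12], 0.0, atol=1e-10)
assert errors[2] > errors[5] > errors[12]
```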

Variational Auto-Encoders

class NumericVAE(nn.Module):
    def __init__(self, n_features, hidden_dim=64, latent_dim=8):
        super().__init__()
        
        # ---------- ENCODER ----------
        # First layer: compress input features into a hidden representation
        self.fc1 = nn.Linear(n_features, hidden_dim)
        
        # Latent space parameters (q(z|x)): mean and log-variance
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)       # μ(x)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)   # log(σ^2(x))
        
        # ---------- DECODER ----------
        # First layer: map latent variable z back into hidden representation
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        
        # Output distribution parameters for reconstruction p(x|z)
        # For numeric data, we predict both mean and log-variance per feature
        self.fc_out_mu = nn.Linear(hidden_dim, n_features)        # μ_x(z)
        self.fc_out_logvar = nn.Linear(hidden_dim, n_features)    # log(σ^2_x(z))

    # ENCODER forward pass: input x -> latent mean, log-variance
    def encode(self, x):
        h = F.relu(self.fc1(x))       # Hidden layer with ReLU
        mu = self.fc_mu(h)            # Latent mean vector
        logvar = self.fc_logvar(h)    # Latent log-variance vector
        return mu, logvar

    # Reparameterization trick: sample z = μ + σ * ε  (ε ~ N(0,1))
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)   # σ = exp(0.5 * logvar)
        eps = torch.randn_like(std)     # ε ~ N(0, I)
        return mu + eps * std           # z = μ + σ * ε

    # DECODER forward pass: latent z -> reconstructed mean, log-variance
    def decode(self, z):
        h = F.relu(self.fc2(z))             # Hidden layer with ReLU
        recon_mu = self.fc_out_mu(h)        # Mean of reconstructed features
        recon_logvar = self.fc_out_logvar(h)# Log-variance of reconstructed features
        return recon_mu, recon_logvar

    # Full forward pass: input x -> reconstructed (mean, logvar), latent params
    def forward(self, x):
        mu, logvar = self.encode(x)            # q(z|x)
        z = self.reparameterize(mu, logvar)    # Sample z from q(z|x)
        recon_mu, recon_logvar = self.decode(z)# p(x|z)
        return (recon_mu, recon_logvar), mu, logvar

    # Sample new synthetic data: z ~ N(0,I), decode to x
    def generate(self, n_samples=100):
        self.eval()
        with torch.no_grad():
            # Sample z from standard normal prior
            z = torch.randn(n_samples, self.fc_mu.out_features)
            
            # Decode to get reconstruction distribution parameters
            cont_mu, cont_logvar = self.decode(z)
            
            # Sample from reconstructed Gaussian: μ_x + σ_x * ε
            return cont_mu + torch.exp(0.5 * cont_logvar) * torch.randn_like(cont_mu)
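The reparameterization step deserves a standalone check: writing z = μ + σ·ε with ε ~ N(0, I) produces samples with the intended moments while keeping μ and σ inside the differentiable graph. A minimal demonstration with toy values, independent of the class above:

```python
import torch

torch.manual_seed(0)

# Toy reparameterization check: z = mu + sigma * eps, eps ~ N(0, I),
# yields draws from N(mu, sigma^2). Values here are illustrative,
# not taken from the model.
mu = torch.tensor([1.0, -2.0])
logvar = torch.tensor([0.0, 2.0])        # sigma = exp(0.5 * logvar)
std = torch.exp(0.5 * logvar)

eps = torch.randn(100_000, 2)
z = mu + eps * std

# Empirical moments should approximate the intended mu and sigma.
assert torch.allclose(z.mean(dim=0), mu, atol=0.05)
assert torch.allclose(z.std(dim=0), std, atol=0.05)
```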
def vae_loss(recon_mu, recon_logvar, x, mu, logvar):
    # Reconstruction loss: Gaussian log likelihood
    recon_var = torch.exp(recon_logvar)
    recon_nll = 0.5 * (torch.log(2 * torch.pi * recon_var) + (x - recon_mu) ** 2 / recon_var)
    recon_loss = recon_nll.sum(dim=1).mean()  # sum over features, mean over batch

    # KL divergence: D_KL(q(z|x) || p(z)) where p(z)=N(0,I)
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    kl_loss = kl_div.mean()

    return recon_loss + kl_loss, recon_loss, kl_loss
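The KL term has a closed form worth sanity-checking: it vanishes exactly when the approximate posterior equals the standard-normal prior, and each unit shift of the mean contributes ½ per dimension. A self-contained check of the same formula:

```python
import torch

# Closed-form KL between q = N(mu, diag(sigma^2)) and p = N(0, I),
# matching the expression used in the loss above.
def kl_std_normal(mu, logvar):
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

# When q equals the prior, the KL is exactly zero ...
assert kl_std_normal(torch.zeros(1, 8), torch.zeros(1, 8)).item() == 0.0

# ... and shifting the mean by 1 in each of 8 dimensions adds
# 0.5 per dimension, giving KL = 4.
assert torch.isclose(
    kl_std_normal(torch.ones(1, 8), torch.zeros(1, 8)),
    torch.tensor([4.0]),
)
```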
def prep_data_vae(sample_size=1000):
    sample_df = make_sample(cov_matrix=cov_matrix, size=sample_size, columns=columns)

    X_train, X_test = train_test_split(sample_df.values, test_size=0.2, random_state=890)

    X_train = torch.tensor(X_train, dtype=torch.float32)
    X_test = torch.tensor(X_test, dtype=torch.float32)

    train_loader = torch.utils.data.DataLoader(X_train, batch_size=32, shuffle=True)
    test_loader = torch.utils.data.DataLoader(X_test, batch_size=32)
    return train_loader, test_loader



def train_vae(vae, optimizer, train, test, patience=30, n_epochs=1000):
    best_loss = float('inf')
    best_state = vae.state_dict()
    wait = 0
    losses = []

    for epoch in range(n_epochs):
        vae.train()
        train_loss = 0.0

        for batch in train:
            optimizer.zero_grad()

            (recon_mu, recon_logvar), mu, logvar = vae(batch)
            loss, recon_loss, kl_loss = vae_loss(recon_mu, recon_logvar, batch, mu, logvar)

            loss.backward()
            optimizer.step()

            train_loss += loss.item() * batch.size(0)

        avg_train_loss = train_loss / train.dataset.shape[0]

        # --- Test Loop ---
        vae.eval()
        test_loss = 0.0
        with torch.no_grad():
            for batch in test:
                (recon_mu, recon_logvar), mu, logvar = vae(batch)
                loss, _, _ = vae_loss(recon_mu, recon_logvar, batch, mu, logvar)
                test_loss += loss.item() * batch.size(0)
        avg_test_loss = test_loss / test.dataset.shape[0]

        print(f"Epoch {epoch+1}/{n_epochs} | Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f}")

        # Early stopping on the average test loss
        if avg_test_loss < best_loss - 1e-4:
            best_loss, wait = avg_test_loss, 0
            best_state = vae.state_dict()
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                vae.load_state_dict(best_state)  # restore best weights
                break
        losses.append([avg_train_loss, avg_test_loss, best_loss])

    return vae, pd.DataFrame(losses, columns=['train_loss', 'test_loss', 'best_loss'])

train_500, test_500 = prep_data_vae(500)
vae = NumericVAE(n_features=train_500.dataset.shape[1], hidden_dim=64)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
vae_fit_500, losses_df_500 = train_vae(vae, optimizer, train_500, test_500)


train_1000, test_1000 = prep_data_vae(1000)
vae = NumericVAE(n_features=train_1000.dataset.shape[1], hidden_dim=64)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
vae_fit_1000, losses_df_1000 = train_vae(vae, optimizer, train_1000, test_1000)

train_10_000, test_10_000 = prep_data_vae(10_000)
vae = NumericVAE(n_features=train_10_000.dataset.shape[1], hidden_dim=64)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
vae_fit_10_000, losses_df_10_000 = train_vae(vae, optimizer, train_10_000, test_10_000)
Epoch 1/1000 | Train Loss: 15.3680 | Test Loss: 14.4466
[... per-epoch logs abridged ...]
Early stopping at epoch 164    (500-observation run; best test loss 11.3230)
Epoch 1/1000 | Train Loss: 15.3143 | Test Loss: 14.5206
[... per-epoch logs abridged ...]
Early stopping at epoch 101    (1,000-observation run; best test loss 12.0073)
Epoch 1/1000 | Train Loss: 13.8512 | Test Loss: 13.3184
[... per-epoch logs abridged ...]
Early stopping at epoch 77     (10,000-observation run; best test loss 11.7419)
fig, axs = plt.subplots(1, 3, figsize=(8, 6))
axs=axs.flatten()
losses_df_500[['train_loss', 'test_loss']].plot(ax=axs[0])
losses_df_1000[['train_loss', 'test_loss']].plot(ax=axs[1])
losses_df_10_000[['train_loss', 'test_loss']].plot(ax=axs[2])

axs[0].set_title("Training and Test Losses \n 500 observations");
axs[1].set_title("Training and Test Losses \n 1000 observations");
axs[2].set_title("Training and Test Losses \n 10_000 observations");

def bootstrap_residuals(vae_fit, X_test, sample_df, n_boot=1000):
    recons = []
    resid_array = np.zeros((n_boot, len(sample_df.columns), len(sample_df.columns)))
    for i in range(n_boot):
        recon_data = vae_fit.generate(n_samples=len(X_test))
        reconstructed_df = pd.DataFrame(recon_data, columns=sample_df.columns)
        resid = pd.DataFrame(X_test, columns=sample_df.columns).corr() - reconstructed_df.corr()
        resid_array[i] = resid.values
        recons.append(reconstructed_df)

    avg_resid = resid_array.mean(axis=0)
    bootstrapped_resids = pd.DataFrame(avg_resid, columns=sample_df.columns, index=sample_df.columns)
    return bootstrapped_resids
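On a toy example, the elementwise averaging of bootstrapped residual matrices (the `resid_array.mean(axis=0)` step above) looks like this; the two matrices here are made-up stand-ins for bootstrap draws:

```python
import numpy as np

# Two made-up 2x2 residual matrices standing in for bootstrap draws.
resid_array = np.stack([np.eye(2) * 0.1, np.eye(2) * 0.3])
avg_resid = resid_array.mean(axis=0)   # elementwise average, as in bootstrap_residuals
print(avg_resid.tolist())              # ≈ [[0.2, 0.0], [0.0, 0.2]]
```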

bootstrapped_resids_500 = bootstrap_residuals(vae_fit_500, pd.DataFrame(test_500.dataset, columns=sample_df.columns), sample_df)

bootstrapped_resids_1000 = bootstrap_residuals(vae_fit_1000, pd.DataFrame(test_1000.dataset, columns=sample_df.columns), sample_df)

bootstrapped_resids_10_000 = bootstrap_residuals(vae_fit_10_000, pd.DataFrame(test_10_000.dataset, columns=sample_df.columns), sample_df)


fig, axs = plt.subplots(3, 1, figsize=(10, 20))
axs = axs.flatten()
plot_heatmap(bootstrapped_resids_500, title="""Expected Correlation Residuals for 500 observations \n Under 1000 Bootstrapped Reconstructions""", ax=axs[0], colorbar=True, vmin=-.25, vmax=.25)

plot_heatmap(bootstrapped_resids_1000, title="""Expected Correlation  Residuals  for 1000 observations \n Under 1000 Bootstrapped Reconstructions""", ax=axs[1], colorbar=True, vmin=-.25, vmax=.25)

plot_heatmap(bootstrapped_resids_10_000, title="""Expected Correlation  Residuals  for 10,000 observations \n Under 1000 Bootstrapped Reconstructions""", ax=axs[2], colorbar=True, vmin=-.25, vmax=.25)


Missing Data

sample_df_missing = sample_df.copy()

# Randomly pick 5% of the total elements
mask_remove = np.random.rand(*sample_df_missing.shape) < 0.05

# Set those elements to NaN
sample_df_missing[mask_remove] = np.nan
sample_df_missing.head()
JW1 JW2 JW3 UF1 UF2 FOR DA1 DA2 DA3 EBA ST MI
0 0.330438 0.326503 0.583599 0.298558 1.112858 0.371895 0.247663 0.704001 0.739605 NaN 1.546412 1.282066
1 NaN -0.598124 -0.882404 0.034036 -0.220583 -0.443250 0.452083 0.976292 1.460018 0.208454 -0.237027 0.124897
2 -0.334853 -0.171359 -0.862147 -0.696685 0.294389 -0.671320 0.023049 -0.117460 0.394511 0.769453 1.138158 0.216388
3 -0.765717 0.554349 0.062522 0.181065 0.609657 0.595781 -0.056995 -0.635932 -0.330862 -0.424084 -0.548190 NaN
4 NaN -0.369919 -0.210114 0.185822 -0.755927 -0.490341 -0.472250 -0.419797 0.123084 1.157508 -0.840009 -0.144719
import torch
import torch.nn as nn
import torch.nn.functional as F


class MissingDataDataset(Dataset):
    def __init__(self, x, mask):
        # x and mask are tensors of same shape
        self.x = x
        self.mask = mask
        
    def __len__(self):
        return self.x.shape[0]
    
    def __getitem__(self, idx):
        return self.x[idx], self.mask[idx]
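A quick usage sketch of the dataset (class body repeated verbatim so the snippet is self-contained): each item comes back as a `(values, mask)` pair, with NaNs left in place in the values.

```python
import torch
from torch.utils.data import Dataset

# Class body repeated from above so this sketch runs standalone.
class MissingDataDataset(Dataset):
    def __init__(self, x, mask):
        self.x = x
        self.mask = mask

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        return self.x[idx], self.mask[idx]

x = torch.tensor([[1.0, float('nan')], [3.0, 4.0]])
mask = (~torch.isnan(x)).float()   # 1 = observed, 0 = missing
ds = MissingDataDataset(x, mask)
xi, mi = ds[0]
print(len(ds), mi.tolist())        # 2 [1.0, 0.0]
```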

def prep_data_vae_missing(sample_size=1000, batch_size=32):
    sample_df = make_sample(cov_matrix=cov_matrix, size=sample_size, columns=columns)

    # Inject ~5% missingness so the encoder actually sees NaNs during training;
    # otherwise the masks below would be all-ones and the missing-data machinery unused
    mask_remove = np.random.rand(*sample_df.shape) < 0.05
    sample_df[mask_remove] = np.nan

    X_train, X_test = train_test_split(sample_df.values, test_size=0.2, random_state=890)

    # Mask: 1=observed, 0=missing
    mask_train = ~pd.DataFrame(X_train).isna()
    mask_test = ~pd.DataFrame(X_test).isna()

    # Tensors (keep NaNs for missing values)
    x_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    mask_train_tensor = torch.tensor(mask_train.values, dtype=torch.float32)

    x_test_tensor = torch.tensor(X_test, dtype=torch.float32)
    mask_test_tensor = torch.tensor(mask_test.values, dtype=torch.float32)

    train_dataset = MissingDataDataset(x_train_tensor, mask_train_tensor)
    test_dataset = MissingDataDataset(x_test_tensor, mask_test_tensor)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader


import torch
import torch.nn as nn
import torch.nn.functional as F

class NumericVAE_missing(nn.Module):
    def __init__(self, n_features, hidden_dim=64, latent_dim=8):
        super().__init__()
        self.n_features = n_features

        # ---------- Learnable Imputation ----------
        # One learnable parameter per feature for missing values
        self.missing_embeddings = nn.Parameter(torch.zeros(n_features))

        # ---------- ENCODER ----------
        self.fc1_x = nn.Linear(n_features, hidden_dim)

        # Stronger mask encoder: 2-layer MLP
        self.fc1_mask = nn.Sequential(
            nn.Linear(n_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Combine feature and mask embeddings
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # ---------- DECODER ----------
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc_out_mu = nn.Linear(hidden_dim, n_features)
        self.fc_out_logvar = nn.Linear(hidden_dim, n_features)

    def encode(self, x, mask):
        # Impute missing values with learnable parameters
        x_filled = torch.where(
            torch.isnan(x),
            self.missing_embeddings.expand_as(x),
            x
        )

        # Encode features and mask separately
        h_x = F.relu(self.fc1_x(x_filled))
        h_mask = self.fc1_mask(mask)

        # Combine embeddings
        h = h_x + h_mask

        # Latent space
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc2(z))
        recon_mu = self.fc_out_mu(h)
        recon_logvar = self.fc_out_logvar(h)
        return recon_mu, recon_logvar

    def forward(self, x, mask):
        mu, logvar = self.encode(x, mask)
        z = self.reparameterize(mu, logvar)
        recon_mu, recon_logvar = self.decode(z)
        return (recon_mu, recon_logvar), mu, logvar

    def generate(self, n_samples=100):
        self.eval()
        with torch.no_grad():
            z = torch.randn(n_samples, self.fc_mu.out_features)
            recon_mu, recon_logvar = self.decode(z)
            return recon_mu + torch.exp(0.5 * recon_logvar) * torch.randn_like(recon_mu)
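The `reparameterize` step is just the location-scale transform z = μ + ε·exp(½ log σ²) with ε ~ N(0, 1); a small numpy check (independent of the class above) confirms the samples have the intended moments:

```python
import numpy as np

# Standalone check of the reparameterization trick:
# z = mu + eps * std with eps ~ N(0, 1) yields samples from N(mu, std^2).
rng = np.random.default_rng(0)
mu, logvar = 2.0, np.log(0.25)          # target N(2, 0.25)
eps = rng.standard_normal(200_000)
z = mu + eps * np.exp(0.5 * logvar)     # same transform as in the class above
print(z.mean(), z.var())                # ~2.0, ~0.25
```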
def vae_loss_missing(recon_mu, recon_logvar, x, mu, logvar, mask):
    # Fill missing values with 0 just for loss computation
    x_filled = torch.where(mask.bool(), x, torch.zeros_like(x))

    recon_var = torch.exp(recon_logvar)
    recon_nll = 0.5 * (torch.log(2 * torch.pi * recon_var) +
                       (x_filled - recon_mu) ** 2 / recon_var)

    # Mask out missing values
    recon_nll = recon_nll * mask
    obs_counts = mask.sum(dim=1).clamp(min=1)
    recon_loss = (recon_nll.sum(dim=1) / obs_counts).mean()

    # KL divergence
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    kl_loss = kl_div.mean()

    return recon_loss, kl_loss
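A small numpy re-implementation of the reconstruction term (mirroring, not importing, the function above) verifies that values at masked-out positions cannot affect the loss:

```python
import numpy as np

def masked_gaussian_nll(x_filled, recon_mu, recon_logvar, mask):
    # Per-feature Gaussian NLL, zeroed where mask == 0 and normalized
    # by the number of observed features per row (as in vae_loss_missing).
    var = np.exp(recon_logvar)
    nll = 0.5 * (np.log(2 * np.pi * var) + (x_filled - recon_mu) ** 2 / var)
    nll = nll * mask
    obs = np.clip(mask.sum(axis=1), 1, None)
    return (nll.sum(axis=1) / obs).mean()

mask = np.array([[1.0, 0.0], [1.0, 1.0]])   # second feature of row 0 is missing
mu = np.zeros((2, 2))
logvar = np.zeros((2, 2))
x_a = np.array([[0.5, 0.0], [1.0, -1.0]])
x_b = x_a.copy()
x_b[0, 1] = 99.0                            # perturb only the masked-out entry

loss_a = masked_gaussian_nll(x_a, mu, logvar, mask)
loss_b = masked_gaussian_nll(x_b, mu, logvar, mask)
print(loss_a == loss_b)  # True: the missing entry contributes nothing
```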

def beta_annealing(epoch, max_beta=1.0, anneal_epochs=100):
    """Linearly ramp the KL weight from 0 to max_beta over anneal_epochs."""
    return min(max_beta, max_beta * epoch / anneal_epochs)
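Restating the schedule so it can be checked standalone: with `anneal_epochs=10` (the value used in the training loop), beta rises linearly from 0 and then saturates at `max_beta`:

```python
# Restated from above so the snippet runs standalone.
def beta_annealing(epoch, max_beta=1.0, anneal_epochs=100):
    return min(max_beta, max_beta * epoch / anneal_epochs)

schedule = [beta_annealing(e, max_beta=1.0, anneal_epochs=10) for e in (0, 5, 10, 20)]
print(schedule)  # [0.0, 0.5, 1.0, 1.0]
```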
train_loader, test_loader = prep_data_vae_missing(10_000, batch_size=32)

vae_missing = NumericVAE_missing(n_features=next(iter(train_loader))[0].shape[1])
optimizer = optim.Adam(vae_missing.parameters(), lr=1e-4)

best_loss = float('inf')
patience, wait = 30, 0
losses = []

n_epochs = 1000
for epoch in range(n_epochs):
    beta = beta_annealing(epoch, max_beta=1.0, anneal_epochs=10)
    vae_missing.train()
    
    train_loss = 0
    for x_batch, mask_batch in train_loader:
        optimizer.zero_grad()
        (recon_mu, recon_logvar), mu, logvar = vae_missing(x_batch, mask_batch)
        recon_loss, kl_loss = vae_loss_missing(recon_mu, recon_logvar, x_batch, mu, logvar, mask_batch)
        loss = recon_loss + beta * kl_loss
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * x_batch.size(0)

    avg_train_loss = train_loss / len(train_loader.dataset)

    # --- Validation ---
    vae_missing.eval()
    test_loss = 0.0
    with torch.no_grad():
        for x_batch, mask_batch in test_loader:
            (recon_mu, recon_logvar), mu, logvar = vae_missing(x_batch, mask_batch)
            recon_loss, kl_loss = vae_loss_missing(recon_mu, recon_logvar, x_batch, mu, logvar, mask_batch)
            loss = recon_loss + kl_loss
            test_loss += loss.item() * x_batch.size(0)
    avg_test_loss = test_loss / len(test_loader.dataset)

    print(f"Epoch {epoch+1}/{n_epochs} | Train: {avg_train_loss:.4f} | Test: {avg_test_loss:.4f}")

    # Early stopping on the average test loss (the summed loss would make the
    # 1e-4 tolerance scale with the dataset size)
    if avg_test_loss < best_loss - 1e-4:
        best_loss, wait = avg_test_loss, 0
        best_state = vae_missing.state_dict()
    else:
        wait += 1
        if wait >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            vae_missing.load_state_dict(best_state)  # restore best weights
            break
    losses.append([avg_train_loss, avg_test_loss, best_loss])
bootstrapped_resids_missing = bootstrap_residuals(vae_missing, pd.DataFrame(test_loader.dataset.x.numpy(), columns=sample_df.columns), sample_df)

plot_heatmap(bootstrapped_resids_missing)

recons = []
n_boot = 500
resid_array = np.zeros((n_boot, len(sample_df_missing.columns), len(sample_df_missing.columns)))
for i in range(n_boot):
    recon_data = vae_missing.generate(n_samples=len(sample_df_missing))
    reconstructed_df = pd.DataFrame(recon_data.numpy(), columns=sample_df_missing.columns)
    resid = pd.DataFrame(test_loader.dataset.x.numpy(), columns=sample_df_missing.columns).corr() - reconstructed_df.corr()
    resid_array[i] = resid.values
    recons.append(reconstructed_df)

avg_resid = resid_array.mean(axis=0)
bootstrapped_resids = pd.DataFrame(avg_resid, columns=sample_df_missing.columns, index=sample_df_missing.columns)

plot_heatmap(bootstrapped_resids, title="""Expected Residuals \n Under Bootstrapped Reconstructions""")

recon_data = vae_missing.generate(n_samples=len(sample_df_missing))

# Rebuild imputed DataFrame. Note: generate() samples unconditionally from the
# prior, so these fills reflect the marginal distribution, not the observed row.
imputed_array = sample_df_missing.to_numpy().copy()
imputed_array[mask_remove] = recon_data.numpy()[mask_remove]
imputed_df = pd.DataFrame(imputed_array, columns=sample_df_missing.columns)

imputed_df.head()
JW1 JW2 JW3 UF1 UF2 FOR DA1 DA2 DA3 EBA ST MI
0 0.330438 0.326503 0.583599 0.298558 1.112858 0.371895 0.247663 0.704001 0.739605 -0.210920 1.546412 1.282066
1 -0.592279 -0.598124 -0.882404 0.034036 -0.220583 -0.443250 0.452083 0.976292 1.460018 0.208454 -0.237027 0.124897
2 -0.334853 -0.171359 -0.862147 -0.696685 0.294389 -0.671320 0.023049 -0.117460 0.394511 0.769453 1.138158 0.216388
3 -0.765717 0.554349 0.062522 0.181065 0.609657 0.595781 -0.056995 -0.635932 -0.330862 -0.424084 -0.548190 2.183827
4 0.946249 -0.369919 -0.210114 0.185822 -0.755927 -0.490341 -0.472250 -0.419797 0.123084 1.157508 -0.840009 -0.144719
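The boolean-mask assignment that splices generated values into the missing slots can be seen on a tiny made-up array:

```python
import numpy as np

# Made-up values standing in for the generated reconstructions.
data = np.array([[1.0, np.nan], [np.nan, 4.0]])
fill = np.array([[9.0, 9.5], [8.0, 8.5]])

mask_remove = np.isnan(data)
imputed = data.copy()
imputed[mask_remove] = fill[mask_remove]   # same boolean-mask splice as above
print(imputed.tolist())                    # [[1.0, 9.5], [8.0, 4.0]]
```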
plot_heatmap(sample_df.corr() - imputed_df.corr())

fig, axs = plt.subplots(1,2 ,figsize=(9, 30))
axs = axs.flatten()
plot_heatmap(sample_df_missing.head(50).fillna(99), vmin=-0, vmax=99, ax=axs[0], colorbar=False)
axs[0].set_title("Missing Data", fontsize=12)
plot_heatmap(imputed_df.head(50), vmin=-2, vmax=2, ax=axs[1], colorbar=False)
axs[1].set_title("Imputed Data", fontsize=12);
plt.tight_layout()

Bayesian Inference

def make_pymc_model(sample_df):
    coords = {'features': sample_df.columns,
              'features1': sample_df.columns,
              'obs': range(len(sample_df))}

    with pm.Model(coords=coords) as model:
        # Priors
        mus = pm.Normal("mus", 0, 1, dims='features')
        chol, _, _ = pm.LKJCholeskyCov("chol", n=len(sample_df.columns), eta=1.0, sd_dist=pm.HalfNormal.dist(1))
        cov = pm.Deterministic('cov', pm.math.dot(chol, chol.T), dims=('features', 'features1'))

        pm.MvNormal('likelihood', mus, cov=cov, observed=sample_df.values, dims=('obs', 'features'))

        idata = pm.sample_prior_predictive()
        idata.extend(pm.sample(random_seed=120))
        pm.sample_posterior_predictive(idata, extend_inferencedata=True)

    return idata, model

idata, model = make_pymc_model(sample_df)
pm.model_to_graphviz(model)
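For reference, `pm.LKJCholeskyCov` parameterizes the covariance through its Cholesky factor: the `sd_dist` prior (here HalfNormal(1)) governs the scales, while the LKJ prior with η = 1 is uniform over correlation matrices:

```latex
\Sigma = L L^{\top}, \qquad
L = \operatorname{diag}(\sigma)\, L_{\mathrm{corr}}, \qquad
\Sigma_{ij} = \sigma_i\, \sigma_j\, \rho_{ij},
\quad \text{where } \rho = L_{\mathrm{corr}} L_{\mathrm{corr}}^{\top} \sim \mathrm{LKJ}(\eta = 1).
```

The `chol_corr` variable summarized next is exactly this correlation matrix ρ.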

import arviz as az

expected_corr = pd.DataFrame(az.summary(idata, var_names=['chol_corr'])['mean'].values.reshape((12, 12)), columns=sample_df.columns, index=sample_df.columns)

resids = sample_df.corr() - expected_corr
plot_heatmap(resids)

Missing Data

idata_missing, model_missing = make_pymc_model(sample_df_missing)
pm.model_to_graphviz(model_missing)

expected_corr = pd.DataFrame(az.summary(idata_missing, var_names=['chol_corr'])['mean'].values.reshape((12, 12)), columns=sample_df.columns, index=sample_df.columns)

resids = sample_df.corr() - expected_corr
plot_heatmap(resids)

Citation

BibTeX citation:
@online{forde2025,
  author = {Forde, Nathaniel},
  title = {Amortized {Bayesian} {Inference} with {PyTorch}},
  date = {2025-07-25},
  langid = {en},
  abstract = {The cost of generating new sample data can be prohibitive.
    There is a secondary but different cost which attaches to the
    “construction” of novel data. Principal Components Analysis can be
    seen as a technique to optimally reconstruct a complex multivariate
    data set from a lower level compressed dimensional space.
    Variational auto-encoders allow us to achieve yet more flexible
    reconstruction results in non-linear cases. Drawing a new sample
    from the posterior predictive distribution of Bayesian models
    similarly supplies us with insight in the variability of realised
    data. Both methods assume a latent model of the data generating
    process that aims to leverage a compressed representation of the
    data. These are different heuristics with different consequences for
    how we understand the variability in the world. Amortized Bayesian
    inference seeks to unite the two heuristics.}
}
For attribution, please cite this work as:
Forde, Nathaniel. 2025. “Amortized Bayesian Inference with PyTorch.” July 25, 2025.